Datatypes in E(3) Neural Networks

using the se3cnn repository

tutorial by: Tess E. Smidt

code by:


@misc{mario_geiger_2019_3348277,
  author       = {Mario Geiger and
                  Tess Smidt and
                  Wouter Boomsma and
                  Maurice Weiler and
                  Michał Tyszkiewicz and
                  Jes Frellsen and
                  Benjamin K. Miller},
  title        = {mariogeiger/se3cnn: Point cloud support},
  month        = jul,
  year         = 2019,
  doi          = {10.5281/zenodo.3348277},
  url          = {https://doi.org/10.5281/zenodo.3348277}
}

Our data types are geometry and features on that geometry, expressed as geometric tensors.

Most properties of physical systems are expressed in terms of geometric tensors. Scalars (mass), vectors (velocities, forces, polarizations), matrices (polarizability, moment of inertia) and higher rank tensors are all geometric tensors.

Geometric tensors: Cartesian tensors and Spherical Tensors

Geometric tensors are commonly expressed with Cartesian indices $(x, y, z)$ -- we will call these Cartesian tensors. However, there is an equally expressive way of representing geometric tensors as spherical tensors.

Whereas for Cartesian tensors the indices can be interpreted as information along $(x, y, z)$, spherical tensors are indexed by which spherical harmonic they are associated with. These representations can be used interchangeably.

We use spherical tensors in our network because our convolutional filters are expressed in terms of spherical harmonics -- more about that later.
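A minimal sketch of this interchangeability for the simplest case, a vector: in the real spherical harmonic convention used later in this tutorial, the $L=1$ components are the Cartesian components reordered as $(y, z, x)$.

```python
import torch

# A Cartesian vector with components (x, y, z)
v_cart = torch.tensor([1.0, 2.0, 3.0])

# The same vector as an L=1 spherical tensor: components reordered to (y, z, x)
perm = torch.tensor([1, 2, 0])
v_sph = v_cart[perm]
print(v_sph.tolist())  # [2.0, 3.0, 1.0]

# The permutation is invertible, so the two representations carry
# identical information
assert torch.equal(v_sph[torch.argsort(perm)], v_cart)
```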

Wikipedia has a great overview of spherical harmonics. As a quick recap, the spherical harmonics are the Fourier basis for functions on the unit sphere. They have two indices, most commonly called the "degree" $L$ and "order" $m$ and are commonly parameterized by spherical coordinate angles $\theta$ and $\phi$.

$Y_{l}^{m}(\theta, \phi)$ for complex spherical harmonics or $Y_{lm}(\theta, \phi)$ for real spherical harmonics.

In se3cnn, we use real spherical harmonics. There are $2 L + 1$ functions (indexed by $m$) for each $L$. Functions of degree $L$ have the same frequency. Note that these frequencies must be integer (or half-integer for $SU(2)$) because of the periodic boundary conditions on the sphere.
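A quick sanity check of this counting (plain Python, no se3cnn needed):

```python
# There are 2L + 1 real spherical harmonics for each degree L
counts = [2 * L + 1 for L in range(4)]
print(counts)  # [1, 3, 5, 7]

# The total number of functions up to L_max is (L_max + 1)^2
L_max = 3
assert sum(2 * L + 1 for L in range(L_max + 1)) == (L_max + 1) ** 2
```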

Representation Lists in se3cnn

To keep track of which spherical tensor entries correspond to which spherical harmonic, we use representation lists, commonly saved as a variable Rs.

Rs is a list of tuples (mult, L) where mult is the multiplicity (or number of copies) and L is the degree of the spherical harmonic.

For example, the Rs of a single vector is Rs_vec = [(1, 1)] and of two vectors is Rs_2vec = [(2, 1)].

You will sometimes see an Rs with three integers in the tuple (mult, L, parity), where the first two are the same as before and parity indicates whether that part of the tensor has the same (0) or opposite (1) parity as the spherical harmonic of degree L. All odd-$L$ spherical harmonics have odd parity (they change sign under parity) and all even-$L$ spherical harmonics have even parity (they do NOT change under parity).
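The dimension of a spherical tensor is determined entirely by its Rs. Here is a small helper sketching that bookkeeping (`rs_dim` is written for this tutorial, not part of se3cnn):

```python
def rs_dim(Rs):
    """Total number of entries in a spherical tensor with representation list Rs.

    Each (mult, L) entry contributes mult * (2L + 1) coefficients.
    Tuples may also carry a parity as a third element, which does not
    change the dimension, so we only unpack the first two entries.
    """
    return sum(mult * (2 * L + 1) for mult, L, *_ in Rs)

Rs_vec = [(1, 1)]                   # one vector: 3 entries
Rs_2vec = [(2, 1)]                  # two vectors: 6 entries
Rs_3x3 = [(1, 0), (1, 1), (1, 2)]  # a 3x3 matrix reduced to L = 0, 1, 2: 1 + 3 + 5 = 9

print(rs_dim(Rs_vec), rs_dim(Rs_2vec), rs_dim(Rs_3x3))  # 3 6 9
```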

Spherical Harmonics

First, let's draw the spherical harmonics using the SphericalTensor class defined in spherical.py. This is a handy helper class that I've written for this tutorial so we can quickly manipulate and plot spherical tensors.

In [1]:
import torch 
import numpy as np
import se3cnn.SO3 as SO3
from spherical import SphericalTensor # a small Signal class written for ease of handling Spherical Tensors
import plotly
from plotly.subplots import make_subplots

torch.set_default_dtype(torch.float64)

L_max = 3
rows = L_max + 1
cols = 2 * L_max + 1

specs = [[{'is_3d': True} for i in range(cols)]
         for j in range(rows)]
fig = make_subplots(rows=rows, cols=cols, specs=specs)

for L in range(L_max + 1):
    for m in range(0, 2 * L + 1):
        tensor = torch.zeros(2 * L + 1)
        tensor[m] = 1.0
        sphten = SphericalTensor(tensor, Rs=[(1, L)])
        row, col = L + 1, (L_max - L) + m + 1
        trace = sphten.plot(relu=False, n=60)
        if m != 2 * L_max:
            trace.showscale = False
        fig.add_trace(trace, row=row, col=col)

fig.show()

Spherical harmonics as linear combination of monomials

To understand how the spherical harmonics are grouped, it can be helpful to think of the spherical harmonics as being built from monomials proportional to $x^\alpha y^\beta z^\gamma$ where $L = \alpha + \beta + \gamma$. For $L=0$ there is only 1 spherical harmonic and 1 monomial ($1$), for $L=1$ there are 3 spherical harmonics and 3 monomials $(y, z, x)$, for $L=2$ there are 5 spherical harmonics but 6 monomials $(x^2, y^2, z^2, xy, yz, zx)$.
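The counting makes the mismatch concrete: there are $(L+1)(L+2)/2$ monomials of total degree $L$ in three variables, but only $2L+1$ spherical harmonics of degree $L$ (plain Python):

```python
# Number of monomials x^a y^b z^c with a + b + c = L (stars and bars)
def n_monomials(L):
    return (L + 1) * (L + 2) // 2

for L in range(4):
    print(L, n_monomials(L), 2 * L + 1)
# L=0: 1 vs 1,  L=1: 3 vs 3,  L=2: 6 vs 5,  L=3: 10 vs 7
```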

How do we go from 6 to 5? Well, there's a hidden redundancy in these 6 monomials. $x^2$, $y^2$, and $z^2$ are mixtures of L=0 and L=2 which stems from the fact that $x^2 + y^2 + z^2 = r^2$ which is a scalar. We can calculate how these monomials project onto spherical tensors.

In [2]:
empty = torch.zeros(3, 3)
x2, y2, z2 = empty.clone(), empty.clone(), empty.clone()
x2[0, 0], y2[1, 1], z2[2, 2] = 1, 1, 1 # Create tensor representation of x^2, y^2 and z^2

perm_x2 = x2.clone()[torch.tensor([1, 2, 0]).unsqueeze(1), torch.tensor([1, 2, 0]).unsqueeze(0)]
perm_y2 = y2.clone()[torch.tensor([1, 2, 0]).unsqueeze(1), torch.tensor([1, 2, 0]).unsqueeze(0)]
perm_z2 = z2.clone()[torch.tensor([1, 2, 0]).unsqueeze(1), torch.tensor([1, 2, 0]).unsqueeze(0)]

# Representation lists that we use in `se3cnn` indicate the order of coefficients in a spherical tensor
# A component of a representation list describes the multiplicity and degree (mult, L)
Rs_vec = [(1, 1)] # Representation list of a single vector
Rs_3x3, C = SO3.reduce_tensor_product(Rs_vec, Rs_vec) # Rep list and Clebsch-Gordan coefficients

print("x^2, y^2, and z^2 are mixtures of L=0 and L=2")
print("SH:", "  1      y      z      x      xy     yz     *      zx     %", )
print("x^2", torch.einsum('ijk,ij->k', C, perm_x2).detach().numpy().round(3))
print("y^2", torch.einsum('ijk,ij->k', C, perm_y2).detach().numpy().round(3))
print("z^2", torch.einsum('ijk,ij->k', C, perm_z2).detach().numpy().round(3))
print("* == 2z^2 - x^2 - y^2")
print("% == x^2 - y^2")
x^2, y^2, and z^2 are mixtures of L=0 and L=2
SH:   1      y      z      x      xy     yz     *      zx     %
x^2 [ 0.577 -0.    -0.    -0.    -0.     0.    -0.408 -0.     0.707]
y^2 [ 0.577 -0.     0.     0.     0.     0.    -0.408  0.    -0.707]
z^2 [ 0.577 -0.     0.     0.     0.    -0.     0.816  0.    -0.   ]
* == 2z^2 - x^2 - y^2
% == x^2 - y^2

3x3 Matrix as a Cartesian and Spherical tensor

Geometric tensors rotate predictably under rotation. Let's take the example of a 3 x 3 matrix, a Cartesian tensor of rank 2.

$M_{ij} = \begin{pmatrix} \alpha_{xx} & \alpha_{xy} & \alpha_{xz} \\ \alpha_{yx} & \alpha_{yy} & \alpha_{yz}\\ \alpha_{zx} & \alpha_{zy} & \alpha_{zz} \end{pmatrix}$

where $i$ and $j$ are indexed as $(x, y, z)$.

We can also express this matrix as a spherical harmonic tensor. The way to do this conversion is to recognize that $L=1$ spherical tensor has the same indices as $(x, y, z)$ EXCEPT they are permuted as $(y, z, x)$.

In [3]:
M = torch.randn(3, 3)
# Permute indices to ('y', 'z', 'x') to be compatible with spherical harmonic convention
perm_M = M.clone()[torch.tensor([1, 2, 0]).unsqueeze(1), torch.tensor([1, 2, 0]).unsqueeze(0)]

import matplotlib.pyplot as plt
%matplotlib inline
fig, axes = plt.subplots(1, 2, figsize=(8, 5));
axes[0].matshow(M)
axes[0].set_title('M');
axes[0].get_xaxis().set_visible(False)
axes[0].get_yaxis().set_visible(False)
for i, x in enumerate(["x", "y", "z"]):
    for j, y in enumerate(["x", "y", "z"]):
        axes[0].text(j - 0.2, i + 0.1, x + y, {'color':'white', 'fontsize': 20})
    
im = axes[1].matshow(perm_M)
axes[1].set_title('M permuted for both indices');
axes[1].get_xaxis().set_visible(False)
axes[1].get_yaxis().set_visible(False)
for i, x in enumerate(["y", "z", "x"]):
    for j, y in enumerate(["y", "z", "x"]):
        axes[1].text(j - 0.2, i + 0.1, x + y, {'color':'white', 'fontsize': 20})
        
fig.colorbar(im, ax=axes[:], shrink=0.75);

Rotating Cartesian and Spherical tensors

Our Cartesian matrix can be rotated with a 3D rotation matrix R applied to each Cartesian index.

$M'_{kl} = R_{ki} R_{lj} M_{ij}$

As shown above, we can permute our Cartesian indices $(x, y, z)$ into those of L=1 spherical harmonics $(y, z, x)$.

We can even simplify the matrix by combining its two indices into a single index using the Clebsch-Gordan coefficients.

$I_k = C_{ijk} M_{ij}$

where $C_{ijk}$ is the Clebsch-Gordan tensor. See Griffiths, Introduction to Quantum Mechanics, Ch. 4 for more details.

We can then rotate this single index using Wigner D-matrices, the rotation matrices for the irreducible representations.

$I'_{i} = D_{ij} I_j$

We can then convert back to the 3x3 matrix format to see that these rotations are indeed equivalent.

In [4]:
# random rotation Euler angles alpha, beta, gamma
angles = torch.rand(3) * torch.tensor([np.pi, 2 * np.pi, np.pi])
rot = SO3.rot(*angles)

rotated_M = torch.einsum('ki,ij,lj->kl', rot, M, rot)

Rs_vec = [(1, 1)] # Representation list of a single vector
Rs_3x3, C = SO3.reduce_tensor_product(Rs_vec, Rs_vec)
print("Single index representation of 3x3 matrix:", Rs_3x3)
print("Shape of Clebsch-Gordon tensor:", C.shape)

# Wigner D matrix -- rotation matrix for irreducible representations
wignerD = SO3.direct_sum(*[SO3.irr_repr(l, *angles) for mul, l, parity in Rs_3x3 for _ in range(mul)])
print("Shape of Wigner-D matrix:", wignerD.shape)

# Convert matrix to representation vector
I = torch.einsum('ijk,ij->k', C, perm_M)
# Rotate representation vector
rotated_I = torch.einsum('ij,j->i', wignerD, I)

# And we can convert this back to our original format to compare
rotated_perm_M = torch.einsum('ijk,k->ij', C, rotated_I)
rotated_M_prime = rotated_perm_M.clone()
rotated_M_prime[torch.tensor([1, 2, 0]).unsqueeze(1), torch.tensor([1, 2, 0]).unsqueeze(0)] = rotated_perm_M.clone()
Single index representation of 3x3 matrix: [(1, 0, 0), (1, 1, 0), (1, 2, 0)]
Shape of Clebsch-Gordon tensor: torch.Size([3, 3, 9])
Shape of Wigner-D matrix: torch.Size([9, 9])
In [5]:
# Visualize M and rotated_M
import matplotlib.pyplot as plt
%matplotlib inline
fig, axes = plt.subplots(1, 3, figsize=(12, 6));
axes[0].matshow(M)
axes[0].set_title('M');
axes[0].get_xaxis().set_visible(False)
axes[0].get_yaxis().set_visible(False)
    
axes[1].matshow(rotated_M)
axes[1].set_title('Rotated M from Cartesian');
axes[1].get_xaxis().set_visible(False)
axes[1].get_yaxis().set_visible(False)

im = axes[2].matshow(rotated_M_prime)
axes[2].set_title("Rotated M from Spherical");
axes[2].get_xaxis().set_visible(False)
axes[2].get_yaxis().set_visible(False)
        
fig.colorbar(im, ax=axes[:], shrink=0.75);

We can interpret our spherical harmonic tensors as components of a traditional geometric tensor or as geometry itself.

In [6]:
rows, cols = 1, 1
specs = [[{'is_3d': True} for i in range(cols)]
         for j in range(rows)]
fig = make_subplots(rows=rows, cols=cols, specs=specs)

L_max = 6
Rs = [(1, L) for L in range(L_max + 1)]
sum_Ls = sum(2 * L + 1 for mult, L in Rs) 

# Random spherical tensor up to L_Max
rand_sph_tensor = torch.randn(sum_Ls)

sphten = SphericalTensor(rand_sph_tensor, Rs)

trace = sphten.plot(relu=False, n=60)
fig.add_trace(trace, row=1, col=1)
fig.show()
In [7]:
# Projection of tetrahedron on origin

rows, cols = 1, 1
specs = [[{'is_3d': True} for i in range(cols)]
         for j in range(rows)]
fig = make_subplots(rows=rows, cols=cols, specs=specs)

L_max = 6
tetra_coords = torch.tensor( # The easiest way to construct a tetrahedron is using opposite corners of a box
    [[0., 0., 0.], [1., 1., 0.], [1., 0., 1.], [0., 1., 1.]]
)
tetra_coords -= tetra_coords.mean(-2)


sphten = SphericalTensor.from_geometry(tetra_coords, L_max)

trace = sphten.plot(relu=False, n=60)
fig.add_trace(trace, row=1, col=1)
fig.show()